/*
 * Copyright (c) 2021 yedf. All rights reserved.
 * Use of this source code is governed by a BSD-style
 * license that can be found in the LICENSE file.
 */

package dtmimp

import (
    "database/sql"
    "encoding/json"
    "errors"
    "fmt"
    "net/http"
    "os"
    "runtime"
    "strconv"
    "strings"
    "sync"
    "time"

    "github.com/dtm-labs/dtm/dtmcli/logger"
    "github.com/go-resty/resty/v2"
)

// Logf an alias of Infof
// Deprecated: use logger.Errorf
var Logf = logger.Infof

// LogRedf an alias of Errorf
// Deprecated: use logger.Errorf
var LogRedf = logger.Errorf

// FatalIfError fatal if error is not nil
// Deprecated: use logger.FatalIfError
var FatalIfError = logger.FatalIfError

// LogIfFatalf fatal if cond is true
// Deprecated: use logger.FatalfIf
var LogIfFatalf = logger.FatalfIf

// AsError wrap a panic value as an error
// interface 转 error
func AsError(x interface{}) error {
    logger.Errorf("panic wrapped to error: '%v'", x)
    if e, ok := x.(error); ok {
        return e
    }
    return fmt.Errorf("%v", x)
}

// P2E panic to error
func P2E(perr *error) {
    if x := recover(); x != nil {
        *perr = AsError(x)
    }
}

// E2P error to panic
func E2P(err error) {
    if err != nil {
        panic(err)
    }
}

// CatchP catch panic to error
func CatchP(f func()) (rerr error) {
    defer P2E(&rerr)
    f()
    return nil
}

// PanicIf name is clear
func PanicIf(cond bool, err error) {
    if cond {
        panic(err)
    }
}

// MustAtoi is string to int
func MustAtoi(s string) int {
    r, err := strconv.Atoi(s)
    if err != nil {
        E2P(errors.New("convert to int error: " + s))
    }
    return r
}

// OrString return the first not empty string
func OrString(ss ...string) string {
    for _, s := range ss {
        if s != "" {
            return s
        }
    }
    return ""
}

// If ternary operator
// 相当于 三目运算 b ? 1: 2
func If(condition bool, trueObj interface{}, falseObj interface{}) interface{} {
    if condition {
        return trueObj
    }
    return falseObj
}

// MustMarshal checked version for marshal
func MustMarshal(v interface{}) []byte {
    b, err := json.Marshal(v)
    E2P(err)
    return b
}

// MustMarshalString string version of MustMarshal
func MustMarshalString(v interface{}) string {
    return string(MustMarshal(v))
}

// MustUnmarshal checked version for unmarshal
func MustUnmarshal(b []byte, obj interface{}) {
    err := json.Unmarshal(b, obj)
    E2P(err)
}

// MustUnmarshalString string version of MustUnmarshal
func MustUnmarshalString(s string, obj interface{}) {
    MustUnmarshal([]byte(s), obj)
}

// MustRemarshal marshal and unmarshal, and check error
func MustRemarshal(from interface{}, to interface{}) {
    b, err := json.Marshal(from)
    E2P(err)
    err = json.Unmarshal(b, to)
    E2P(err)
}

// GetFuncName get current call func name
func GetFuncName() string {
    pc, _, _, _ := runtime.Caller(1)
    nm := runtime.FuncForPC(pc).Name()
    return nm[strings.LastIndex(nm, ".")+1:]
}

// MayReplaceLocalhost when run in docker compose, change localhost to host.docker.internal for accessing host network
func MayReplaceLocalhost(host string) string {
    if os.Getenv("IS_DOCKER") != "" {
        return strings.Replace(host, "localhost", "host.docker.internal", 1)
    }
    return host
}

var sqlDbs sync.Map

// PooledDB get pooled sql.DB
func PooledDB(conf DBConf) (*sql.DB, error) {
    dsn := GetDsn(conf)
    db, ok := sqlDbs.Load(dsn)
    if !ok {
        db2, err := StandaloneDB(conf)
        if err != nil {
            return nil, err
        }
        db = db2
        sqlDbs.Store(dsn, db)
    }
    return db.(*sql.DB), nil
}

// StandaloneDB get a standalone db instance
func StandaloneDB(conf DBConf) (*sql.DB, error) {
    dsn := GetDsn(conf)
    logger.Infof("opening standalone %s: %s", conf.Driver, strings.Replace(dsn, conf.Password, "****", 1))
    return sql.Open(conf.Driver, dsn)
}

// DBExec use raw db to exec
func DBExec(db DB, sql string, values ...interface{}) (affected int64, rerr error) {
    if sql == "" {
        return 0, nil
    }
    began := time.Now()
    sql = GetDBSpecial().GetPlaceHoldSQL(sql)
    r, rerr := db.Exec(sql, values...)
    used := time.Since(began) / time.Millisecond
    if rerr == nil {
        affected, rerr = r.RowsAffected()
        logger.Debugf("used: %d ms affected: %d for %s %v", used, affected, sql, values)
    } else {
        logger.Errorf("used: %d ms exec error: %v for %s %v", used, rerr, sql, values)
    }
    return
}

// GetDsn get dsn from map config
func GetDsn(conf DBConf) string {
    host := MayReplaceLocalhost(conf.Host)
    driver := conf.Driver
    dsn := map[string]string{
        "mysql": fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local&interpolateParams=true",
            conf.User, conf.Password, host, conf.Port, ""),
        "postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' port=%d sslmode=disable",
            host, conf.User, conf.Password, "", conf.Port),
    }[driver]
    PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver))
    return dsn
}

// RespAsErrorCompatible translate a resty response to error
// compatible with version < v1.10
// 根据状态码或者返回结果设置 error
func RespAsErrorCompatible(resp *resty.Response) error {
    code := resp.StatusCode()
    str := resp.String()
    if code == http.StatusTooEarly || strings.Contains(str, ResultOngoing) {
        return fmt.Errorf("%s. %w", str, ErrOngoing)
    } else if code == http.StatusConflict || strings.Contains(str, ResultFailure) {
        return fmt.Errorf("%s. %w", str, ErrFailure)
    } else if code != http.StatusOK {
        return errors.New(str)
    }
    return nil
}

// DeferDo a common defer do used in dtmcli/dtmgrpc
func DeferDo(rerr *error, success func() error, fail func() error) {
    defer func() {
        if x := recover(); x != nil {
            _ = fail()
            panic(x)
        } else if *rerr != nil {
            _ = fail()
        } else {
            *rerr = success()
        }
    }()
}
